import os
import sys
import pandas as pd
import matplotlib.pyplot as plt


def f(filename, k):
    """Print and return block-wise mean/std of the 'Real Det Return' column.

    The CSV at ``filename`` is read with its default positional index, rows
    are grouped into consecutive blocks of ``k`` rows (row ``i`` belongs to
    block ``i // k``), and the mean and standard deviation of
    'Real Det Return' are computed for every block. Each block is printed as
    a LaTeX math snippet of the form ``$mean\\pm std$`` with two decimals.

    Args:
        filename: Path of the CSV file to read.
        k: Number of consecutive rows per block; must be positive.

    Returns:
        A DataFrame indexed by block number with columns 'mean' and 'std'.
        A block holding a single row has a NaN 'std'.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    # A zero or negative block size would make ``index // k`` produce
    # inf/NaN or negative group keys and silently yield meaningless groups.
    if k <= 0:
        raise ValueError(f"k must be a positive block size, got {k!r}")

    df = pd.read_csv(filename)
    print(df.index)

    mean_std_df = df.groupby(df.index // k)['Real Det Return'].agg(['mean', 'std'])

    # Walk the two result columns in lockstep instead of building a Series
    # per row with iterrows().
    for mean, std in zip(mean_std_df['mean'], mean_std_df['std']):
        print("$%.2f\\pm%.2f$" % (mean, std))
    return mean_std_df

# Per-environment plot configuration, in plotting order. 'firl' and 'pagar'
# are progress logs relative to this file's directory; 'xlim'/'ylim' are
# optional fixed axis limits (axes without them are autoscaled).
_EXPERIMENTS = {
    'Ant': {
        'firl': "logs/AntFH-v0/exp-16/fkl/2024_08_10_23_06_12/progress.csv",
        'pagar': "logs/AntFH-v0/exp-16/pagar_fkl/2024_08_11_05_32_11/progress.csv",
        'xlim': (0, 2e6),
        'ylim': (-100, 6000),
    },
    'HalfCheetah': {
        'firl': "logs/HalfCheetahFH-v0/exp-16/fkl/2024_08_11_05_56_06/progress.csv",
        'pagar': "logs/HalfCheetahFH-v0/exp-16/pagar_fkl/2024_08_11_06_02_54/progress.csv",
    },
    'Hopper': {
        'firl': "logs/HopperFH-v0/exp-4/fkl/2024_08_11_05_54_52/progress.csv",
        'pagar': "logs/HopperFH-v0/exp-4/pagar_fkl/2024_08_11_06_05_53/progress.csv",
    },
    'Walker2d': {
        'firl': "logs/Walker2dFH-v0/exp-16/fkl/2024_08_11_05_54_32/progress.csv",
        'pagar': "logs/Walker2dFH-v0/exp-16/pagar_fkl/2024_08_11_06_16_17/progress.csv",
        'xlim': (0, 3e6),
        'ylim': (-100, 6000),
    },
}


def _grouped_return_stats(csv_path, k):
    """Return (steps, mean, std) Series aggregated over blocks of ``k`` rows."""
    df = pd.read_csv(csv_path)
    grouped = df.groupby(df.index // k)
    steps = grouped['Running Env Steps'].mean()
    stats = grouped['Real Det Return'].agg(['mean', 'std'])
    return steps, stats['mean'], stats['std']


def _plot_mean_band(steps, mean, std, label, color):
    """Draw the mean curve plus a translucent mean +/- std band on the current axes."""
    plt.plot(steps, mean, label=label, color=color)
    plt.fill_between(steps, mean - std, mean + std, color=color, alpha=0.2)


def plot_mean_std(k):
    """Plot fIRL vs. PAGAR-fIRL return curves for each configured environment.

    For every entry of ``_EXPERIMENTS`` the two progress CSVs are read, rows
    are grouped into consecutive blocks of ``k`` rows, and the block mean of
    'Running Env Steps' is plotted against the block mean of
    'Real Det Return', with a shaded band of one standard deviation around
    it. Each figure is saved as ``<environment>_pagar_firl.png`` relative to
    the current working directory.

    Only the first figure (Ant) receives a legend and axis labels.
    NOTE(review): presumably the panels are tiled side by side with a shared
    legend -- confirm before changing.

    Args:
        k: Number of consecutive log rows averaged into one plotted point;
            must be positive.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    # A zero or negative block size would make ``index // k`` produce
    # inf/NaN or negative group keys and silently plot meaningless curves.
    if k <= 0:
        raise ValueError(f"k must be a positive block size, got {k!r}")

    base_dir = os.path.dirname(__file__)
    for i, (exp, cfg) in enumerate(_EXPERIMENTS.items()):
        # Load both runs before opening a figure so that a missing or
        # malformed log cannot leave a half-built figure behind.
        firl = _grouped_return_stats(os.path.join(base_dir, cfg['firl']), k)
        pagar = _grouped_return_stats(os.path.join(base_dir, cfg['pagar']), k)

        fig = plt.figure(figsize=(13, 9))
        try:
            if 'xlim' in cfg:
                plt.xlim(*cfg['xlim'])
            if 'ylim' in cfg:
                plt.ylim(*cfg['ylim'])

            # PAGAR-fIRL is drawn first so it is listed first in the legend.
            _plot_mean_band(*pagar, label='PAGAR-fIRL', color='#ff7f0e')
            _plot_mean_band(*firl, label='fIRL', color='#1f77b4')
            plt.grid(True)
            if i == 0:
                plt.legend(fontsize=40)
                plt.xlabel('Steps', fontsize=30)
                plt.ylabel('Average Return', fontsize=30)

            # Enlarge the x-axis offset text (e.g. the "1e6" multiplier) to
            # match the tick labels.
            plt.gca().xaxis.get_offset_text().set_fontsize(30)
            plt.xticks(fontsize=30)
            plt.yticks(fontsize=30)

            plt.savefig(f"{exp}_pagar_firl.png")
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # release it so figures do not accumulate across iterations.
            plt.close(fig)
         



if __name__ == "__main__":
    # Regenerate the comparison figures, averaging every 10 consecutive log
    # rows into one plotted point.
    #
    # Disabled alternative: print LaTeX mean/std entries for a single CSV
    # given on the command line, i.e. f(sys.argv[1], int(sys.argv[2])).
    plot_mean_std(k=10)

